import os
import json
from common import *

# Walk a folder tree and collect JSON files.
def getJsonFilesByPath(dir):
	"""Return the paths of all ``*.json`` files found under *dir*, recursively.

	The extension test is case-insensitive. Each path is built with
	``os.path.join(root, name)``, so paths are relative when *dir* is relative.
	A non-existent *dir* yields an empty list (``os.walk`` reports nothing).
	"""
	jsonFiles = []
	for root, _subdirs, names in os.walk(dir):
		# root: folder currently visited; names: plain files inside it.
		for name in names:
			# endswith() rather than index(): index() raised ValueError for every
			# file without ".json" in its name, and also matched "a.json.bak".
			if name.lower().endswith('.json'):
				jsonFiles.append(os.path.join(root, name))
	return jsonFiles
			
# Import (restore) data points from JSON dump files.
def importJson(files=None):
	"""Load each JSON file in *files* and write its contents to the database.

	The parsed JSON document is passed unchanged to
	``dbClient().write_points`` (presumably a list of point dicts produced by
	the exporter — confirm). A file that cannot be opened is reported and
	skipped; JSON parse errors and database errors propagate to the caller.

	:param files: iterable of file paths; ``None`` (the default) means none.
	"""
	for path in files or ():
		print('import json file: {}'.format(path))
		try:
			# Read-only mode: the original 'r+' needlessly required write access.
			fp = open('{}'.format(path), 'r')
		except OSError:
			# The original formatted an undefined name (`table`) here, turning
			# every open failure into a NameError.
			print('open file:{} error!'.format(path))
			continue
		# Close the file even if reading fails.
		with fp:
			body = fp.read()
		dbClient().write_points(json.loads(body))

# Export measurements in "SQL" form (unfinished stub).
def importSql(files=None):
	"""Query every measurement named in *files* and create a dump file for it.

	NOTE(review): despite its name this exports rather than imports, and it is
	unfinished. For each name it creates/truncates ``<table>_<currentTimeStr()>.text``
	but writes nothing to it. It only prints the object returned by
	``result.get_points`` and the marker ``'sql'``.

	:param files: iterable of measurement names; ``None`` (the default) means none.
	"""
	for table in files or ():
		# NOTE(review): the name is interpolated unquoted into the query; only
		# pass trusted measurement names.
		result = dbClient().query('select * from {}'.format(table))
		try:
			f = open('{}_{}.text'.format(table, currentTimeStr()), 'w+')
		except OSError:
			print('write table:{} error!'.format(table))
			continue
		# `with` closes the file even if the body raises.
		with f:
			points = result.get_points(measurement=table)
			print(points)
			print('sql')

# Export measurements in "CSV" form (unfinished stub).
def importCsv(files=None):
	"""Query every measurement named in *files* and create a dump file for it.

	The original had no ``files`` parameter yet iterated ``files``, so calling
	it raised NameError (or TypeError when given an argument). The parameter
	now mirrors ``importSql``.

	NOTE(review): despite its name this exports rather than imports, and it is
	unfinished. It creates/truncates ``<table>_<currentTimeStr()>.text`` (not
	``.csv``) but writes nothing to it. It only prints the object returned by
	``result.get_points`` and the marker ``'csv'``.

	:param files: iterable of measurement names; ``None`` (the default) means none.
	"""
	for table in files or ():
		# NOTE(review): the name is interpolated unquoted into the query; only
		# pass trusted measurement names.
		result = dbClient().query('select * from {}'.format(table))
		try:
			f = open('{}_{}.text'.format(table, currentTimeStr()), 'w+')
		except OSError:
			print('write table:{} error!'.format(table))
			continue
		# `with` closes the file even if the body raises.
		with f:
			points = result.get_points(measurement=table)
			print(points)
			print('csv')

# Restore the database from dump files in the configured input directory.
def restoreDB():
	"""Restore data from files under ``inputDir()`` in the ``ioFormat()`` format.

	If ``dbTables()`` is blank, every ``*.json`` file under the input directory
	is used, searched recursively. Otherwise each comma-separated entry is
	stripped of surrounding whitespace and joined onto the input directory;
	empty entries are skipped. Only the ``json`` format is implemented. Any
	other format is reported and nothing is restored.
	"""
	baseDir = inputDir()
	tables = dbTables()  # call once; the original queried it twice
	if tables.strip() == '':
		files = getJsonFilesByPath(baseDir)
	else:
		# os.path.join instead of plain concatenation, which produced
		# "dirname" + "table" when the directory lacked a trailing separator.
		files = [os.path.join(baseDir, name.strip())
			for name in tables.split(',') if name.strip()]

	fmt = ioFormat()
	print('format:{}'.format(fmt))
	if fmt == 'json':
		importJson(files)
		print('restore db end!')
	else:
		# sql/csv restores are not implemented: importSql/importCsv query the
		# database instead of reading these files. Say so rather than silently
		# doing nothing.
		print('unsupported format:{}'.format(fmt))

def main():
	"""Script entry point: validate argv, connect to the database, then restore it."""
	# Explicit `== False` preserved: only a value equal to False aborts.
	argsRejected = verifyArgv() == False
	if argsRejected:
		return
	if not connectDB():
		return
	restoreDB()

# Run only when executed directly, not when imported.
if __name__ == '__main__':
	main()